{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyOaHR81EoRiuDjhTvY9PZ+l",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/Baijiong-Lin/LoRA-Torch/blob/main/examples/mergedlinear.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3YYQD-Le8iaB",
        "outputId": "1d5241cf-ce37-4abd-e282-170cba75f654"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting git+https://github.com/microsoft/LoRA\n",
            "  Cloning https://github.com/microsoft/LoRA to /tmp/pip-req-build-b6kloy6q\n",
            "  Running command git clone --filter=blob:none --quiet https://github.com/microsoft/LoRA /tmp/pip-req-build-b6kloy6q\n",
            "  Resolved https://github.com/microsoft/LoRA to commit 998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch\n",
            "  Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-pdln00a3\n",
            "  Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-pdln00a3\n",
            "  Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit f2dccb93ac60313d15251fee12e8b7be5649471a\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "!pip install git+https://github.com/microsoft/LoRA\n",
        "!pip install git+https://github.com/Baijiong-Lin/LoRA-Torch"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch, loralib, loratorch, copy\n",
        "import torch.nn as nn"
      ],
      "metadata": {
        "id": "t-KzlZJT81XD"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "model_lib = loralib.MergedLinear(10, 3*10, bias=False, r=4, enable_lora=[True, False, True], fan_in_fan_out=False)\n",
        "loralib.mark_only_lora_as_trainable(model_lib)\n",
        "optimizer = torch.optim.SGD(model_lib.parameters(), lr=0.1)\n",
        "loss_fn = nn.MSELoss()\n",
        "\n",
        "for _ in range(3):\n",
        "    model_lib.train()\n",
        "    x = torch.rand(5, 10)\n",
        "    y = torch.rand(5, 3*10)\n",
        "    loss = loss_fn(model_lib(x), y)\n",
        "    optimizer.zero_grad()\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    model_lib.eval()\n",
        "    print(model_lib(x).size())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 366
        },
        "id": "LAyOBzAm82dD",
        "outputId": "bf2541a6-3632-41fe-cb18-e5adc63eb2fb"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "error",
          "ename": "RuntimeError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-3-85bb5522e1f5>\u001b[0m in \u001b[0;36m<cell line: 6>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m     \u001b[0mmodel_lib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     16\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_lib\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36meval\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   2305\u001b[0m             \u001b[0mModule\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2306\u001b[0m         \"\"\"\n\u001b[0;32m-> 2307\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2308\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2309\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mrequires_grad_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequires_grad\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/loralib/layers.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, mode)\u001b[0m\n\u001b[1;32m    232\u001b[0m                         \u001b[0mgroups\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_lora\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    233\u001b[0m                     ).squeeze(0)\n\u001b[0;32m--> 234\u001b[0;31m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_pad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdelta_w\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscaling\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    235\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmerged\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    236\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/loralib/layers.py\u001b[0m in \u001b[0;36mzero_pad\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    203\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_zeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    204\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 205\u001b[0;31m         result[:, self.lora_ind] = x.reshape(\n\u001b[0m\u001b[1;32m    206\u001b[0m             \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_features\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_lora\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_lora\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    207\u001b[0m         )\n",
            "\u001b[0;31mRuntimeError\u001b[0m: shape mismatch: value tensor of shape [10, 20] cannot be broadcast to indexing result of shape [20, 20]"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "model_torch = loratorch.MergedLinear(10, 3*10, bias=False, r=4, enable_lora=[True, False, True], fan_in_fan_out=False)\n",
        "loralib.mark_only_lora_as_trainable(model_torch)\n",
        "optimizer = torch.optim.SGD(model_torch.parameters(), lr=0.1)\n",
        "loss_fn = nn.MSELoss()\n",
        "\n",
        "for _ in range(3):\n",
        "    model_torch.train()\n",
        "    x = torch.rand(5, 10)\n",
        "    y = torch.rand(5, 3*10)\n",
        "    loss = loss_fn(model_torch(x), y)\n",
        "    optimizer.zero_grad()\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    model_torch.eval()\n",
        "    print(model_torch(x).size())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GRPPQ-Wk9evZ",
        "outputId": "92684b76-3dc3-47a5-8f8a-a38508850053"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([5, 30])\n",
            "torch.Size([5, 30])\n",
            "torch.Size([5, 30])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "jB55l3PU9veg"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}